#!/usr/bin/env python

import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
from matplotlib import animation
import copy
import time
import rospy
from nav_msgs.msg import Odometry
from geometry_msgs.msg import Twist
import tf
from geometry_msgs.msg import Twist, Quaternion, Vector3

# Seed NumPy's global RNG so the sampled control perturbations are repeatable
# from run to run.
np.random.seed(0)

# Symmetric clipping limit applied to each control channel (rad/s).
WHEEL_VEL_MAX = 6.35492
# Wheel radius and wheel separation used by the differential-drive model and
# by the wheel-speed -> twist conversion.
# NOTE(review): presumably metres, matching the "(m)" plot axis labels -- confirm.
WHEEL_RADIUS = 0.033
WHEEL_BASE = 0.16


def dd_dynamics(x, u):
    """Differential-drive state derivative for a batch of samples.

    Args:
        x: array indexed as ``x[2, :]``; row 2 is the heading of each sample.
        u: array indexed as ``u[0, :]`` / ``u[1, :]``; the two wheel speeds
            of each sample.

    Returns:
        Array stacking the x-rate, y-rate and heading-rate rows, one column
        per sample.
    """
    heading = x[2, :]
    wheel_sum = u[0, :] + u[1, :]
    wheel_diff = u[1, :] - u[0, :]

    # Forward speed is the mean wheel speed scaled by the wheel radius; it is
    # projected on the world axes through the current heading.
    x_rate = (WHEEL_RADIUS / 2.0) * np.cos(heading) * wheel_sum
    y_rate = (WHEEL_RADIUS / 2.0) * np.sin(heading) * wheel_sum
    # Turn rate comes from the wheel-speed difference over the wheel base.
    heading_rate = (WHEEL_RADIUS / WHEEL_BASE) * wheel_diff

    return np.array([x_rate, y_rate, heading_rate])


def unicycle_dynamics(x, u):
    """Unicycle state derivative for a batch of samples.

    Args:
        x: array indexed as ``x[2, :]``; row 2 is the heading of each sample.
        u: array whose row 0 scales the heading direction (forward speed) and
            whose row 1 is passed through as the heading rate.

    Returns:
        Array stacking the x-rate, y-rate and heading-rate rows, one column
        per sample.
    """
    heading = x[2, :]
    forward = u[0, :]
    turn = u[1, :]

    x_rate = np.cos(heading) * forward
    y_rate = np.sin(heading) * forward
    return np.array([x_rate, y_rate, turn])


def rk4(x0, u, dt):
    """Advance the differential-drive state by one Runge-Kutta 4 step.

    The control ``u`` is held constant over the step. After integration the
    heading row (row 2) is wrapped back into ``(-pi, pi]``.

    Args:
        x0: state array with one column per sample; row 2 is the heading.
        u: control array with one column per sample.
        dt: step length.

    Returns:
        New state array of the same layout as ``x0``.
    """

    def increment(x):
        # State change over one step if the derivative at ``x`` held throughout.
        return dt * dd_dynamics(x, u)

    k1 = increment(x0)
    k2 = increment(x0 + k1 / 2)
    k3 = increment(x0 + k2 / 2)
    k4 = increment(x0 + k3)

    # Classic RK4 weighting of the four stage increments.
    xnew = x0 + (1.0 / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)

    # Remove whole turns from the heading so it stays in (-pi, pi].
    theta = xnew[2, :]
    whole_turns = np.ceil((theta + np.pi) / (2.0 * np.pi)) - 1.0
    xnew[2, :] = theta - whole_turns * 2.0 * np.pi
    return xnew


def euler(x0, u, dt):
    """Advance the unicycle state by one explicit Euler step.

    Args:
        x0: state array with one column per sample.
        u: control array with one column per sample.
        dt: step length.

    Returns:
        ``x0`` plus ``dt`` times the unicycle derivative; the heading is not
        wrapped here.
    """
    step = dt * unicycle_dynamics(x0, u)
    return x0 + step


class MPPI:
    """MPPI-style sampling controller over a fixed control horizon.

    Each call to ``get_path`` perturbs the current control sequence with
    Gaussian noise, rolls the perturbed sequences through ``model``, weights
    the perturbations by their exponentiated cost-to-go, applies the first
    control of the updated sequence and then shifts the sequence by one step.

    The visited states, applied controls and their time stamps are recorded
    in ``path`` (K x 3), ``uvec`` (K x 2) and ``fin_time`` (length K).
    Both control channels are clipped to ``+/-WHEEL_VEL_MAX``.
    """

    # Polynomial order of the Savitzky-Golay smoothing applied to the controls.
    _SAVGOL_POLYORDER = 3

    def __init__(self, model=rk4, horizon=100, samples=10, thresh=0.05):
        """Store the tuning parameters and reset the recorded histories.

        Args:
            model: callable ``model(x, u, dt)`` mapping a (3, N) state array
                and a (2, N) control array to the next (3, N) state array.
            horizon: number of control steps in the planned sequence; the
                integration step is ``1 / horizon``.
            samples: number of perturbed control sequences per update.
            thresh: x/y distance to the goal below which ``solve_path`` stops.
        """
        self.horizon = horizon
        self.samples = samples
        self.uvec_init = np.zeros((2, horizon))
        self.model = model
        self.dt = 1.0 / float(horizon)
        # Running state cost: penalises x/y error only (heading weight is 0).
        self.Q = np.diag([1e3, 1e3, 0.0])
        # Control effort cost.
        self.R = np.eye(2)
        # Terminal cost: penalises x, y and heading error.
        self.P1 = np.diag([1e3, 1e3, 1e3])
        self.thresh = thresh
        self.start = np.array([0.0, 0.0, 0.0])
        self.goal = np.array([0.0, 0.0, 0.0])
        self.initialize()

    def initialize(self):
        """Reset the planned control sequence and all recorded histories.

        After this call ``path`` holds only ``self.start``, ``uvec`` holds a
        single zero control and ``fin_time`` holds a single 0 time stamp.
        """
        self.fin_time = [0]  # time stamp of each recorded state/control
        self.latest_uvec = copy.deepcopy(self.uvec_init)
        self.uvec = np.array([self.uvec_init[:, 0]])
        self.path = np.array([self.start])

    def get_path(self,
                 state,
                 goal,
                 sig=np.array([[.9, 0.0], [0.0, .9]]),
                 lam=.001):
        """Run one receding-horizon update and apply its first control.

        Args:
            state: current ``[x, y, theta]``.
            goal: target ``[x, y, theta]``.
            sig: 2x2 noise matrix (see ``get_cost2go`` for how it is used).
            lam: temperature of the exponential weighting.

        Returns:
            The state predicted by ``model`` after applying the first control
            of the updated sequence.
        """
        value_fcn, eps = self.get_cost2go(state, self.latest_uvec, goal, lam,
                                          sig)
        self.latest_uvec = self.update_action(self.latest_uvec, eps, value_fcn,
                                              sig, lam)
        state = self.perform_action(state, self.latest_uvec)

        # Record the new state, the control that produced it and its time.
        self.path = np.concatenate((self.path, np.array([state])))
        self.uvec = np.concatenate(
            (self.uvec, np.array([self.latest_uvec[:, 0]])))
        self.fin_time.append(self.fin_time[-1] + self.dt)

        # Receding horizon: drop the applied control, pad the tail with the
        # initial (zero) control.
        self.latest_uvec[:, :-1] = self.latest_uvec[:, 1:]
        self.latest_uvec[:, -1] = self.uvec_init[:, 0]
        return state

    def solve_path(self,
                   start,
                   goal,
                   sig=np.array([[1.0, 0.0], [0.0, 1.0]]),
                   lam=.01,
                   max_iters=None):
        """Iterate ``get_path`` from ``start`` until the goal x/y is reached.

        Args:
            start: initial ``[x, y, theta]``.
            goal: target ``[x, y, theta]``.
            sig: 2x2 noise matrix forwarded to ``get_path``.
            lam: temperature forwarded to ``get_path``.
            max_iters: optional cap on the number of iterations; ``None``
                (the default) keeps iterating until the goal is reached.
        """
        self.start = start
        self.goal = goal
        # Reset path, controls AND time stamps together so the three
        # histories stay the same length across repeated calls.
        self.initialize()
        state = start
        print("STARTING")
        i = 0
        tstart = time.time()
        while np.linalg.norm(state[:2] - goal[:2]) > self.thresh:
            if max_iters is not None and i >= max_iters:
                print("Stopping: iteration limit of {} reached.".format(
                    max_iters))
                break
            i += 1
            state = self.get_path(state, goal, sig, lam)
            if i % 200 == 0:
                print("Iteration: {} \t State: {}".format(i, state))
        t_elapsed = time.time() - tstart
        # The loop body never runs when the start is already within thresh,
        # so guard the per-iteration average against a zero count.
        per_iter = t_elapsed / float(i) if i else 0.0
        print(
            "Finished after {} iterations. Final State: {}, Time Taken: {}, Time Per Iter: {}"
            .format(i, state, t_elapsed, per_iter))

    def get_cost2go(self, state, uvec, goal, lam, sig):
        """Roll out perturbed control sequences and accumulate their costs.

        Args:
            state: current ``[x, y, theta]``.
            uvec: (2, horizon) nominal control sequence.
            goal: target ``[x, y, theta]``.
            lam: temperature, forwarded to ``get_cost``.
            sig: 2x2 noise matrix.

        Returns:
            ``(value_fcn, eps)`` where ``value_fcn`` is (horizon, samples),
            row ``t`` holding each sample's summed cost from step ``t`` to the
            end (terminal cost included), and ``eps`` is a list of
            ``horizon`` perturbation arrays of shape (2, samples).
        """
        goal = np.asarray(goal, dtype=float)
        # One column per sample; forced to float so integer inputs are not
        # truncated as the rollout advances.
        states = np.tile(np.asarray(state, dtype=float), (self.samples, 1)).T

        eps = []
        cost2go = []
        for t in range(self.horizon):
            # NOTE(review): sig[0, 0] is used as the standard deviation of
            # every channel; the rest of sig is ignored when sampling.
            noise = np.random.normal(0,
                                     sig[0, 0],
                                     size=(uvec.shape[0], self.samples))
            eps.append(noise)

            # Explore around the nominal control, then respect wheel limits.
            u_samp = np.clip(uvec[:, t][:, None] + noise, -WHEEL_VEL_MAX,
                             WHEEL_VEL_MAX)

            # Predict the next state of every sample.
            states = self.model(states, u_samp, self.dt)

            # Running cost of each sample at this step.
            cost2go.append(
                np.array([
                    self.get_cost(states[:, s], goal, uvec[:, t], lam, sig,
                                  noise[:, s]) for s in range(self.samples)
                ]))

        # Terminal cost on the final predicted state of each sample.
        terminal_err = (states - goal[:, None]).T
        cost2go[-1] += np.array(
            [err.dot(self.P1).dot(err) for err in terminal_err])

        # Reverse cumulative sum: cost from each step to the end of horizon.
        value_fcn = np.flip(np.cumsum(np.flip(cost2go, 0), axis=0), 0)

        return value_fcn, eps

    def get_cost(self, state, desired_state, u, lam, sig, eps):
        """Running cost of one sample at one time step.

        Returns half the sum of the ``Q``-weighted state error and the
        ``R``-weighted control effort, plus ``lam * u . sig . eps``.
        """
        # NOTE(review): the coupling term uses sig itself; MPPI derivations
        # usually use the inverse noise covariance here -- confirm intent.
        err = state - desired_state
        state_cost = err.T.dot(self.Q).dot(err)
        control_cost = u.T.dot(self.R).dot(u)
        return 0.5 * (state_cost + control_cost) + lam * u.dot(sig).dot(eps)

    def _savgol_window(self):
        """Return a valid odd smoothing window, or None if none fits."""
        window = self.horizon - 1
        if window % 2 == 0:
            # Odd windows are accepted by every SciPy release; even ones are
            # rejected by older releases.
            window -= 1
        # The filter needs polyorder < window_length.
        if window <= self._SAVGOL_POLYORDER:
            return None
        return window

    def update_action(self, uvec, eps, value_fcn, sig, lam):
        """Fold the weighted perturbations into the control sequence.

        Args:
            uvec: (2, horizon) nominal control sequence; updated in place
                before smoothing.
            eps: perturbations returned by ``get_cost2go``.
            value_fcn: cost-to-go returned by ``get_cost2go``.
            sig: unused here; kept for interface compatibility.
            lam: temperature of the exponential weighting.

        Returns:
            The updated, smoothed and clipped (2, horizon) control sequence.
        """
        for t in range(self.horizon):
            # Shift by the best sample's cost so the exponent stays <= 0.
            shifted = value_fcn[t] - np.amin(value_fcn[t])
            # Usefulness of each perturbation sample at step t; the small
            # offset keeps every weight strictly positive.
            omg = np.exp(-shifted / lam) + 1e-8
            omg /= np.sum(omg)
            uvec[:, t] += np.dot(eps[t], omg)

        uvec = np.clip(uvec, -WHEEL_VEL_MAX, WHEEL_VEL_MAX)

        # Smooth the controls along time when the horizon is long enough for
        # a valid filter window; smoothing can overshoot, so clip again.
        window = self._savgol_window()
        if window is not None:
            uvec = savgol_filter(uvec, window, self._SAVGOL_POLYORDER, axis=1)
            uvec = np.clip(uvec, -WHEEL_VEL_MAX, WHEEL_VEL_MAX)

        return uvec

    def perform_action(self, state, uvec):
        """Apply the first control of ``uvec`` to ``state`` for one step.

        The state and control are duplicated into two columns only because
        the models index their inputs as 2-D arrays; column 0 is returned.
        """
        state = np.tile(state, (2, 1)).T
        u = np.tile(uvec[:, 0], (2, 1)).T
        return self.model(state, u, self.dt)[:, 0]

    def get_animation(self):
        """Render the recorded path, states and controls to ``sim.mp4``."""
        print("generating animation...")
        path = np.array(self.path)  # recorded robot states
        control = np.array(self.uvec)  # controls that produced the path
        fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, ncols=1, figsize=(10, 10))
        fig.suptitle("MPPI", fontsize=12)

        def format_position_axes():
            # ax1 is cleared on every frame, so labels/limits are re-applied.
            ax1.set(xlabel="X position (m)", ylabel="Y position (m)")
            ax1.set_xlim(np.amin(path[:, 0]) - 0.1, np.amax(path[:, 0]) + 0.1)
            ax1.set_ylim(np.amin(path[:, 1]) - 0.1, np.amax(path[:, 1]) + 0.1)

        format_position_axes()

        ax2.set(xlabel="Time (s)", ylabel="State (m or rad)")
        ax2.set_xlim(0, self.fin_time[-1])
        ax2.set_ylim(np.amin(path) - 0.1, np.amax(path) + 0.1)
        line21, = ax2.plot([], [])
        line22, = ax2.plot([], [])
        line23, = ax2.plot([], [])
        ax2.legend(["x", "y", r"$\theta$"])

        ax3.set(xlabel="Time (s)", ylabel="Wheel Velocity (rad/s)")
        ax3.set_xlim(0, self.fin_time[-1])
        ax3.set_ylim(np.amin(control) - 0.1, np.amax(control) + 0.1)
        line31, = ax3.plot([], [])
        line32, = ax3.plot([], [])
        ax3.legend(["u_l", "u_r"])

        def animate(i):
            ax1.clear()
            format_position_axes()
            # Heading arrow at the robot's position for this frame.
            ax1.arrow(x=path[i, 0],
                      y=path[i, 1],
                      dx=0.025 * np.cos(path[i, 2]),
                      dy=0.025 * np.sin(path[i, 2]),
                      head_width=0.03,
                      length_includes_head=True,
                      overhang=0.1,
                      facecolor="black",
                      zorder=0)
            line11, = ax1.plot([], [])
            line11.set_data(path[0:i, 0], path[0:i, 1])
            line21.set_data(self.fin_time[0:i], path[0:i, 0])
            line22.set_data(self.fin_time[0:i], path[0:i, 1])
            line23.set_data(self.fin_time[0:i], path[0:i, 2])
            line31.set_data(self.fin_time[0:i], control[0:i, 0])
            line32.set_data(self.fin_time[0:i], control[0:i, 1])
            return line11, line21, line22, line23, line31, line32

        # The interval is in milliseconds; adjust it to match dt if needed.
        anim = animation.FuncAnimation(fig,
                                       animate,
                                       frames=len(self.fin_time),
                                       interval=10,
                                       repeat=False)
        print("saving animation as mp4...this might take a while...")
        anim.save(filename='sim.mp4', fps=30, dpi=300)


def plot_path(path, start, goal):
    """Show the x/y trace with its endpoints, then each state over steps.

    Args:
        path: array with one row per recorded state and columns x, y, theta.
        start: sequence whose first two entries are the start x and y.
        goal: sequence whose first two entries are the goal x and y.
    """
    # Figure 1: trajectory in the plane plus start/goal markers.
    plt.figure()
    plt.plot(path[:, 0], path[:, 1], label="Path")
    for point, name in ((start, "Start"), (goal, "Goal")):
        plt.plot(point[0], point[1], marker='o', markersize=3, label=name)
    plt.legend()
    plt.show()

    # Figure 2: every state component against the step index.
    plt.figure()
    for column, name in enumerate(("x", "y", "theta")):
        plt.plot(path[:, column], label=name)
    plt.legend()
    plt.show()


def plot_controls(uvec):
    """Show both control channels against the step index.

    Args:
        uvec: array with one row per step and the two controls as columns.
    """
    plt.figure()
    for column, name in enumerate(("u1", "u2")):
        plt.plot(uvec[:, column], label=name)
    plt.legend()
    plt.show()


class Controller:
    """ROS glue: feeds odometry into MPPI and publishes ``cmd_vel`` twists.

    With a non-empty ``waypoints`` parameter the controller cycles through
    the listed x/y points indefinitely. Without it, the controller drives
    to the fixed goal ``[0.0, -1.0, 0.0]`` ("parallel park") and then
    publishes zero velocity.
    """

    def __init__(self):
        """Create the planner and publisher, then subscribe to odometry."""
        self.mppi = MPPI()
        self.tw_pub = rospy.Publisher('cmd_vel', Twist, queue_size=1)
        self.done = False

        # Read the parameter once, with a default: without one, get_param
        # raises KeyError when the parameter is unset, which is exactly the
        # case the parallel-park fallback is meant to handle.
        waypoints = rospy.get_param("waypoints", None)
        self.parallel_park = not waypoints
        if not self.parallel_park:
            self.waypoints = waypoints

        self.idx = 0
        self.init = True
        self.state = self.mppi.start
        self._publish_twist(0.0, 0.0)
        self.last_time = time.time()

        # Subscribe last: the callback may fire as soon as the subscriber
        # exists, and it reads the attributes assigned above.
        self.sub_pos = rospy.Subscriber('odom',
                                        Odometry,
                                        self.pos_cb,
                                        queue_size=1)

    def wheelsToTwist(self, wheel_vels):
        """Convert ``[left, right]`` wheel speeds to ``(vx, wz)``.

        Returns:
            Tuple of forward speed and yaw rate derived from the wheel radius
            and wheel base.
        """
        ul = wheel_vels[0]
        ur = wheel_vels[1]

        vx = WHEEL_RADIUS * (ul + ur) / 2.0
        wz = WHEEL_RADIUS * (-ul + ur) / WHEEL_BASE
        return vx, wz

    def _publish_twist(self, vx, wz):
        """Publish a Twist carrying only forward speed and yaw rate."""
        tw = Twist()
        tw.linear.x = vx
        tw.angular.z = wz
        self.tw_pub.publish(tw)

    def _aim_at_current_waypoint(self):
        """Set the MPPI goal to waypoint ``idx``, heading along the approach."""
        # Waypoints only carry x, y; the goal heading is the bearing from the
        # current pose to the waypoint.
        goal = self.waypoints[self.idx]
        theta = np.arctan2(goal[1] - self.mppi.start[1],
                           goal[0] - self.mppi.start[0])
        self.mppi.goal = np.array([goal[0], goal[1], theta])
        self.state = self.mppi.start

    def pos_cb(self, odom):
        """Odometry callback: replan one step and publish the command.

        Args:
            odom: ``nav_msgs/Odometry`` message with the current pose.
        """
        try:
            # Current pose becomes the planner's start state.
            pose = odom.pose.pose
            orn = pose.orientation
            theta = tf.transformations.euler_from_quaternion(
                [orn.x, orn.y, orn.z, orn.w])[2]
            self.mppi.start = np.array(
                [pose.position.x, pose.position.y, theta])
            if self.parallel_park:
                self.mppi.goal = np.array([0.0, -1.0, 0.0])

            far_from_goal = np.linalg.norm(
                self.mppi.start[:2] - self.mppi.goal[:2]) > self.mppi.thresh

            if self.init:
                # First message: reset the planner and pick the first goal.
                self.mppi.initialize()
                if not self.parallel_park:
                    self._aim_at_current_waypoint()
                self.init = False
                rospy.loginfo("NEXT WAYPOINT: {}".format(self.mppi.goal[:2]))
            elif far_from_goal:
                # Still travelling: do another round of MPPI.
                self.state = self.mppi.get_path(self.mppi.start,
                                                self.mppi.goal)
                self.done = False
            elif not self.parallel_park:
                # Goal reached: advance to the next waypoint, wrapping around.
                self.idx = (self.idx + 1) % len(self.waypoints)
                self.mppi.initialize()
                rospy.loginfo(
                    "WAYPOINT REACHED: {} \n NEXT WAYPOINT: {}".format(
                        self.mppi.goal[:2], self.waypoints[self.idx]))
                self._aim_at_current_waypoint()
            else:
                # Parallel-park goal reached: stop.
                self.done = True

            # Latest recorded control, or zero wheel speeds once finished.
            if not self.done:
                u = self.mppi.uvec[-1, :]
            else:
                u = np.array([0.0, 0.0])
            vx, wz = self.wheelsToTwist(u)
            self._publish_twist(vx, wz)

        except rospy.ROSInterruptException:
            # Raised on node shutdown; nothing left to do in the callback.
            pass


def main():
    """Start the ``mppi`` ROS node and block in the ROS event loop.

    Offline use without ROS (kept here for reference)::

        mppi = MPPI()
        start = np.array([0.0, 0.0, np.pi / 2.0])
        goal = np.array([1.0, 0.0, np.pi / 2.0])
        mppi.solve_path(start, goal)
        plot_path(mppi.path, start, goal)
        plot_controls(mppi.uvec)
        mppi.get_animation()
    """
    rospy.init_node('mppi', anonymous=True)
    # Constructing the controller registers its odometry callback.
    node = Controller()
    rospy.spin()


if __name__ == "__main__":
    main()
